Upcast gradually when computing variance #4283


Merged 1 commit into main on Aug 6, 2025

Conversation

ftynse
Member

@ftynse ftynse commented Jul 25, 2025

Going all the way to f64 is undesirable, especially for low-precision tensors in bf16 or f8 variants. Upcast only to the next type, e.g., bf16->f32 or f8->bf16. This is consistent with what PyTorch appears to do internally.
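For context (this demo is not from the PR), a small NumPy sketch of why some upcast is needed during variance accumulation, and why one step up can already suffice: accumulating in the input's own 16-bit type can overflow the running sum, while accumulating one precision level up, in f32, gives the exact answer here.

```python
import numpy as np

# Hypothetical demo (not from the PR): variance of 4096 identical fp16
# values. Accumulating in float16 overflows the running sum
# (4096 * 100 = 409600 > float16 max of 65504), while accumulating
# one precision level up, in float32, is exact for this input.
x = np.full(4096, 100.0, dtype=np.float16)

v16 = np.var(x, dtype=np.float16)  # accumulate in float16: overflows
v32 = np.var(x, dtype=np.float32)  # accumulate in float32: exact

print(v16, v32)
```

This is the trade-off the PR targets: accumulation must be wider than the input, but it need not jump all the way to f64.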


Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse force-pushed the users/ftynse/dont-upscale-variance branch from b661b9c to 5b0479c Compare July 30, 2025 13:13
@ftynse ftynse changed the title Don't upscast when computing variance Upcast gradually when computing variance Jul 30, 2025
@ftynse ftynse requested a review from qedawkins July 30, 2025 13:22
@ftynse ftynse marked this pull request as ready for review July 30, 2025 13:22
Collaborator

@qedawkins qedawkins left a comment


LGTM as long as CI is happy

@ftynse ftynse merged commit ac4657a into main Aug 6, 2025
3 checks passed
@ftynse ftynse deleted the users/ftynse/dont-upscale-variance branch August 6, 2025 19:51
Comment on lines +9325 to 9334
// Upcasting the input tensor to a double-bitwidth dtype for higher precision
// during the computation of the result.
unsigned bitwidth = inputTensorTy.getDtype().getIntOrFloatBitWidth();
if (bitwidth != 64) {
Type targetTy = rewriter.getF64Type();
if (bitwidth == 8)
targetTy = rewriter.getBF16Type();
else if (bitwidth == 16)
targetTy = rewriter.getF32Type();
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
Collaborator

Don't you need to replace rewriter.getF64Type() with the targetTy here?
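The selection logic in the diff picks an accumulation width one step up from the input width, rather than jumping straight to f64. A hypothetical Python rendering of that rule (function name is illustrative, not from the patch):

```python
def variance_accum_bitwidth(bitwidth: int) -> int:
    """Illustrative sketch of the diff's dtype selection: upcast the
    accumulation type one step rather than always going to 64 bits.
    (Name and signature are hypothetical, not from the patch.)"""
    if bitwidth == 64:
        return 64   # already 64-bit: no upcast needed
    if bitwidth == 8:
        return 16   # f8 variants -> bf16
    if bitwidth == 16:
        return 32   # f16 / bf16 -> f32
    return 64       # f32 (and anything else below 64) -> f64

print(variance_accum_bitwidth(8),
      variance_accum_bitwidth(16),
      variance_accum_bitwidth(32))
```

Note that in the quoted diff the computed `targetTy` is then unused: `convertTensorToDtype` is still called with `rewriter.getF64Type()`, which is exactly what the review question points out.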


3 participants